{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Transfer learning & fine-tuning\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2020/04/15<br>\n",
    "**Last modified:** 2020/05/12<br>\n",
    "**Description:** Complete guide to transfer learning & fine-tuning in Keras."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "**Transfer learning** consists of taking features learned on one problem, and\n",
    "leveraging them on a new, similar problem. For instance, features from a model that has\n",
    "learned to identify raccoons may be useful to kick-start a model meant to identify\n",
    " tanukis.\n",
    "\n",
    "Transfer learning is usually done for tasks where your dataset has too little data to\n",
    " train a full-scale model from scratch.\n",
    "\n",
    "The most common incarnation of transfer learning in the context of deep learning is the\n",
    " following workflow:\n",
    "\n",
    "1. Take layers from a previously trained model.\n",
    "2. Freeze them, so as to avoid destroying any of the information they contain during\n",
    " future training rounds.\n",
    "3. Add some new, trainable layers on top of the frozen layers. They will learn to turn\n",
    " the old features into predictions on a new dataset.\n",
    "4. Train the new layers on your dataset.\n",
    "\n",
    "A last, optional step is **fine-tuning**, which consists of unfreezing the entire\n",
    "model you obtained above (or part of it), and re-training it on the new data with a\n",
    "very low learning rate. This can potentially achieve meaningful improvements, by\n",
    " incrementally adapting the pretrained features to the new data.\n",
    "\n",
    "First, we will go over the Keras `trainable` API in detail, which underlies most\n",
    " transfer learning & fine-tuning workflows.\n",
    "\n",
    "Then, we'll demonstrate the typical workflow by taking a model pretrained on the\n",
    "ImageNet dataset, and retraining it on the Kaggle \"cats vs dogs\" classification\n",
    " dataset.\n",
    "\n",
    "This is adapted from\n",
    "[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python)\n",
    " and the 2016 blog post\n",
    "[\"building powerful image classification models using very little\n",
    " data\"](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Freezing layers: understanding the `trainable` attribute\n",
    "\n",
    "Layers & models have three weight attributes:\n",
    "\n",
    "- `weights` is the list of all weight variables of the layer.\n",
    "- `trainable_weights` is the list of those that are meant to be updated (via gradient\n",
    " descent) to minimize the loss during training.\n",
    "- `non_trainable_weights` is the list of those that aren't meant to be trained.\n",
    " Typically they are updated by the model during the forward pass.\n",
    "\n",
    "**Example: the `Dense` layer has 2 trainable weights (kernel & bias)**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "layer = keras.layers.Dense(3)\n",
    "layer.build((None, 4))  # Create the weights\n",
    "\n",
    "print(\"weights:\", len(layer.weights))\n",
    "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
    "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In general, all weights are trainable weights. The only built-in layer that has\n",
    "non-trainable weights is the `BatchNormalization` layer. It uses non-trainable weights\n",
    " to keep track of the mean and variance of its inputs during training.\n",
    "To learn how to use non-trainable weights in your own custom layers, see the\n",
    "[guide to writing new layers from scratch](making_new_layers_and_models_via_subclassing).\n",
    "\n",
    "**Example: the `BatchNormalization` layer has 2 trainable weights and 2 non-trainable\n",
    " weights**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "layer = keras.layers.BatchNormalization()\n",
    "layer.build((None, 4))  # Create the weights\n",
    "\n",
    "print(\"weights:\", len(layer.weights))\n",
    "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
    "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Layers & models also feature a boolean attribute `trainable`. Its value can be changed.\n",
    "Setting `layer.trainable` to `False` moves all the layer's weights from trainable to\n",
    "non-trainable.  This is called \"freezing\" the layer: the state of a frozen layer won't\n",
    "be updated during training (either when training with `fit()` or when training with\n",
    " any custom loop that relies on `trainable_weights` to apply gradient updates).\n",
    "\n",
    "**Example: setting `trainable` to `False`**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "layer = keras.layers.Dense(3)\n",
    "layer.build((None, 4))  # Create the weights\n",
    "layer.trainable = False  # Freeze the layer\n",
    "\n",
    "print(\"weights:\", len(layer.weights))\n",
    "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
    "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "When a trainable weight becomes non-trainable, its value is no longer updated during\n",
    " training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Make a model with 2 layers\n",
    "layer1 = keras.layers.Dense(3, activation=\"relu\")\n",
    "layer2 = keras.layers.Dense(3, activation=\"sigmoid\")\n",
    "model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])\n",
    "\n",
    "# Freeze the first layer\n",
    "layer1.trainable = False\n",
    "\n",
    "# Keep a copy of the weights of layer1 for later reference\n",
    "initial_layer1_weights_values = layer1.get_weights()\n",
    "\n",
    "# Train the model\n",
    "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
    "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n",
    "\n",
    "# Check that the weights of layer1 have not changed during training\n",
    "final_layer1_weights_values = layer1.get_weights()\n",
    "np.testing.assert_allclose(\n",
    "    initial_layer1_weights_values[0], final_layer1_weights_values[0]\n",
    ")\n",
    "np.testing.assert_allclose(\n",
    "    initial_layer1_weights_values[1], final_layer1_weights_values[1]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Do not confuse the `layer.trainable` attribute with the argument `training` in\n",
    "`layer.__call__()` (which controls whether the layer should run its forward pass in\n",
    " inference mode or training mode). For more information, see the\n",
    "[Keras FAQ](\n",
    "  https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).\n"
   ]
  },
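  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the distinction concrete, here is a minimal sketch. The `Dropout` and `Dense`\n",
    "layers and the input values are illustrative assumptions, not part of the workflow\n",
    "described in this guide: `training` changes how the forward pass behaves, while\n",
    "`trainable` only decides which weights receive gradient updates.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "\n",
    "# Dropout has no weights, so `trainable` is irrelevant to it,\n",
    "# but its forward pass depends on the `training` argument.\n",
    "dropout = keras.layers.Dropout(0.5)\n",
    "x = np.ones((1, 8), dtype=\"float32\")\n",
    "\n",
    "# Inference mode: the layer passes its input through unchanged.\n",
    "out_inference = np.array(dropout(x, training=False))\n",
    "# Training mode: random units are zeroed, the rest are scaled by 1 / (1 - 0.5).\n",
    "out_training = np.array(dropout(x, training=True))\n",
    "\n",
    "# Freezing a Dense layer does not touch its forward pass;\n",
    "# it only moves the kernel & bias to `non_trainable_weights`.\n",
    "dense = keras.layers.Dense(2)\n",
    "dense.build((None, 8))\n",
    "dense.trainable = False\n",
    "print(\"trainable_weights:\", len(dense.trainable_weights))\n",
    "print(\"non_trainable_weights:\", len(dense.non_trainable_weights))\n",
    "```\n",
    "\n",
    "`BatchNormalization` is the exception to this separation, as discussed in the\n",
    "fine-tuning section of this guide.\n"
   ]
  },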
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Recursive setting of the `trainable` attribute\n",
    "\n",
    "If you set `trainable = False` on a model or on any layer that has sublayers,\n",
    "all children layers become non-trainable as well.\n",
    "\n",
    "**Example:**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "inner_model = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(3,)),\n",
    "        keras.layers.Dense(3, activation=\"relu\"),\n",
    "        keras.layers.Dense(3, activation=\"relu\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "model = keras.Sequential(\n",
    "    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation=\"sigmoid\"),]\n",
    ")\n",
    "\n",
    "model.trainable = False  # Freeze the outer model\n",
    "\n",
    "assert inner_model.trainable == False  # All layers in `model` are now frozen\n",
    "assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## The typical transfer-learning workflow\n",
    "\n",
    "This leads us to how a typical transfer learning workflow can be implemented in Keras:\n",
    "\n",
    "1. Instantiate a base model and load pre-trained weights into it.\n",
    "2. Freeze all layers in the base model by setting `trainable = False`.\n",
    "3. Create a new model on top of the output of one (or several) layers from the base\n",
    " model.\n",
    "4. Train your new model on your new dataset.\n",
    "\n",
    "Note that an alternative, more lightweight workflow could also be:\n",
    "\n",
    "1. Instantiate a base model and load pre-trained weights into it.\n",
    "2. Run your new dataset through it and record the output of one (or several) layers\n",
    " from the base model. This is called **feature extraction**.\n",
    "3. Use that output as input data for a new, smaller model.\n",
    "\n",
    "A key advantage of that second workflow is that you only run the base model once on\n",
    " your data, rather than once per epoch of training. So it's a lot faster & cheaper.\n",
    "\n",
    "An issue with that second workflow, though, is that it doesn't allow you to dynamically\n",
    "modify the input data of your new model during training, which is required when doing\n",
    "data augmentation, for instance. Transfer learning is typically used for tasks where\n",
    "your new dataset has too little data to train a full-scale model from scratch, and in\n",
    "such scenarios data augmentation is very important. So in what follows, we will focus\n",
    " on the first workflow.\n",
    "\n",
    "Here's what the first workflow looks like in Keras:\n",
    "\n",
    "First, instantiate a base model with pre-trained weights.\n",
    "\n",
    "```python\n",
    "base_model = keras.applications.Xception(\n",
    "    weights='imagenet',  # Load weights pre-trained on ImageNet.\n",
    "    input_shape=(150, 150, 3),\n",
    "    include_top=False)  # Do not include the ImageNet classifier at the top.\n",
    "```\n",
    "\n",
    "Then, freeze the base model.\n",
    "\n",
    "```python\n",
    "base_model.trainable = False\n",
    "```\n",
    "\n",
    "Create a new model on top.\n",
    "\n",
    "```python\n",
    "inputs = keras.Input(shape=(150, 150, 3))\n",
    "# We make sure that the base_model is running in inference mode here,\n",
    "# by passing `training=False`. This is important for fine-tuning, as you will\n",
    "# learn in a few paragraphs.\n",
    "x = base_model(inputs, training=False)\n",
    "# Convert features of shape `base_model.output_shape[1:]` to vectors\n",
    "x = keras.layers.GlobalAveragePooling2D()(x)\n",
    "# A Dense classifier with a single unit (binary classification)\n",
    "outputs = keras.layers.Dense(1)(x)\n",
    "model = keras.Model(inputs, outputs)\n",
    "```\n",
    "\n",
    "Train the model on new data.\n",
    "\n",
    "```python\n",
    "model.compile(optimizer=keras.optimizers.Adam(),\n",
    "              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
    "              metrics=[keras.metrics.BinaryAccuracy()])\n",
    "model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)\n",
    "```\n"
   ]
  },
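  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For comparison, the second (feature extraction) workflow could be sketched as follows.\n",
    "This is only an illustration under stated assumptions: the small `Dense` network stands\n",
    "in for a real pre-trained base model, and `x_new`, `y_new` and `top_model` are\n",
    "hypothetical names with random toy data.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "\n",
    "# Stand-in for a pre-trained base model (hypothetical; in practice this\n",
    "# would be something like Xception loaded with ImageNet weights).\n",
    "base_model = keras.Sequential(\n",
    "    [keras.Input(shape=(8,)), keras.layers.Dense(4, activation=\"relu\")]\n",
    ")\n",
    "base_model.trainable = False\n",
    "\n",
    "# Toy data standing in for the new dataset.\n",
    "x_new = np.random.random((16, 8)).astype(\"float32\")\n",
    "y_new = np.random.randint(0, 2, size=(16, 1)).astype(\"float32\")\n",
    "\n",
    "# Run the data through the base model once and record its output.\n",
    "features = base_model.predict(x_new, verbose=0)\n",
    "\n",
    "# Train a new, smaller model directly on the recorded features.\n",
    "top_model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])\n",
    "top_model.compile(\n",
    "    optimizer=\"adam\", loss=keras.losses.BinaryCrossentropy(from_logits=True)\n",
    ")\n",
    "top_model.fit(features, y_new, epochs=2, verbose=0)\n",
    "```\n",
    "\n",
    "Because `features` is computed a single time, the base model never runs during\n",
    "`fit()`, which is also why random augmentation of the inputs is not possible here.\n"
   ]
  },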
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Fine-tuning\n",
    "\n",
    "Once your model has converged on the new data, you can try to unfreeze all or part of\n",
    " the base model and retrain the whole model end-to-end with a very low learning rate.\n",
    "\n",
    "This is an optional last step that can potentially give you incremental improvements.\n",
    " It could also potentially lead to quick overfitting -- keep that in mind.\n",
    "\n",
    "It is critical to only do this step *after* the model with frozen layers has been\n",
    "trained to convergence. If you mix randomly-initialized trainable layers with\n",
    "trainable layers that hold pre-trained features, the randomly-initialized layers will\n",
    "cause very large gradient updates during training, which will destroy your pre-trained\n",
    " features.\n",
    "\n",
    "It's also critical to use a very low learning rate at this stage, because\n",
    "you are training a much larger model than in the first round of training, on a dataset\n",
    " that is typically very small.\n",
    "As a result, you are at risk of overfitting very quickly if you apply large weight\n",
    " updates. Here, you only want to readapt the pretrained weights in an incremental way.\n",
    "\n",
    "This is how to implement fine-tuning of the whole base model:\n",
    "\n",
    "```python\n",
    "# Unfreeze the base model\n",
    "base_model.trainable = True\n",
    "\n",
    "# It's important to recompile your model after you make any changes\n",
    "# to the `trainable` attribute of any inner layer, so that your changes\n",
    "# are taken into account\n",
    "model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate\n",
    "              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
    "              metrics=[keras.metrics.BinaryAccuracy()])\n",
    "\n",
    "# Train end-to-end. Be careful to stop before you overfit!\n",
    "model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)\n",
    "```\n",
    "\n",
    "**Important note about `compile()` and `trainable`**\n",
    "\n",
    "Calling `compile()` on a model is meant to \"freeze\" the behavior of that model. This\n",
    " implies that the `trainable`\n",
    "attribute values at the time the model is compiled should be preserved throughout the\n",
    " lifetime of that model,\n",
    "until `compile` is called again. Hence, if you change any `trainable` value, make sure\n",
    " to call `compile()` again on your\n",
    "model for your changes to be taken into account.\n",
    "\n",
    "**Important notes about `BatchNormalization` layer**\n",
    "\n",
    "Many image models contain `BatchNormalization` layers. That layer is a special case on\n",
    " every imaginable count. Here are a few things to keep in mind.\n",
    "\n",
    "- `BatchNormalization` contains 2 non-trainable weights that get updated during\n",
    "training. These are the variables tracking the mean and variance of the inputs.\n",
    "- When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will\n",
    "run in inference mode, and will not update its mean & variance statistics. This is not\n",
    "the case for other layers in general, as\n",
    "[weight trainability & inference/training modes are two orthogonal concepts](\n",
    "  https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).\n",
    "But the two are tied in the case of the `BatchNormalization` layer.\n",
    "- When you unfreeze a model that contains `BatchNormalization` layers in order to do\n",
    "fine-tuning, you should keep the `BatchNormalization` layers in inference mode by\n",
    " passing `training=False` when calling the base model.\n",
    "Otherwise the updates applied to the non-trainable weights will suddenly destroy\n",
    "what the model has learned.\n",
    "\n",
    "You'll see this pattern in action in the end-to-end example at the end of this guide.\n",
    "\n"
   ]
  },
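  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The text above mentions unfreezing only *part* of the base model. One possible way to\n",
    "do that is sketched below; the three-layer `Dense` network is a hypothetical stand-in\n",
    "for a pre-trained model, and keeping only the last layer trainable is an arbitrary\n",
    "choice made for illustration.\n",
    "\n",
    "```python\n",
    "from tensorflow import keras\n",
    "\n",
    "# Stand-in for a pre-trained base model (hypothetical small network).\n",
    "base_model = keras.Sequential(\n",
    "    [\n",
    "        keras.Input(shape=(8,)),\n",
    "        keras.layers.Dense(8, activation=\"relu\"),\n",
    "        keras.layers.Dense(8, activation=\"relu\"),\n",
    "        keras.layers.Dense(8, activation=\"relu\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Unfreeze the whole model, then re-freeze every layer except the last one.\n",
    "base_model.trainable = True\n",
    "for layer in base_model.layers[:-1]:\n",
    "    layer.trainable = False\n",
    "\n",
    "print(\"trainable weights:\", len(base_model.trainable_weights))\n",
    "print(\"non-trainable weights:\", len(base_model.non_trainable_weights))\n",
    "```\n",
    "\n",
    "As with full unfreezing, remember to call `compile()` again on the outer model after\n",
    "changing any `trainable` value.\n"
   ]
  },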
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Transfer learning & fine-tuning with a custom training loop\n",
    "\n",
    "If instead of `fit()`, you are using your own low-level training loop, the workflow\n",
    "stays essentially the same. You should be careful to only take into account the list\n",
    " `model.trainable_weights` when applying gradient updates:\n",
    "\n",
    "```python\n",
    "# Create base model\n",
    "base_model = keras.applications.Xception(\n",
    "    weights='imagenet',\n",
    "    input_shape=(150, 150, 3),\n",
    "    include_top=False)\n",
    "# Freeze base model\n",
    "base_model.trainable = False\n",
    "\n",
    "# Create new model on top.\n",
    "inputs = keras.Input(shape=(150, 150, 3))\n",
    "x = base_model(inputs, training=False)\n",
    "x = keras.layers.GlobalAveragePooling2D()(x)\n",
    "outputs = keras.layers.Dense(1)(x)\n",
    "model = keras.Model(inputs, outputs)\n",
    "\n",
    "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
    "optimizer = keras.optimizers.Adam()\n",
    "\n",
    "# Iterate over the batches of a dataset.\n",
    "for inputs, targets in new_dataset:\n",
    "    # Open a GradientTape.\n",
    "    with tf.GradientTape() as tape:\n",
    "        # Forward pass.\n",
    "        predictions = model(inputs)\n",
    "        # Compute the loss value for this batch.\n",
    "        loss_value = loss_fn(targets, predictions)\n",
    "\n",
    "    # Get gradients of loss wrt the *trainable* weights.\n",
    "    gradients = tape.gradient(loss_value, model.trainable_weights)\n",
    "    # Update the weights of the model.\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Likewise for fine-tuning.\n"
   ]
  },
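  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a rough sketch of what fine-tuning with a custom loop might look like, the example\n",
    "below unfreezes a base model and trains end-to-end with a very low learning rate. The\n",
    "small `Dense` base model, the random toy data and the names `features`, `labels` and\n",
    "`dataset` are assumptions made so that the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "\n",
    "# Stand-in for a pre-trained base model (hypothetical small network).\n",
    "base_model = keras.Sequential(\n",
    "    [keras.Input(shape=(8,)), keras.layers.Dense(4, activation=\"relu\")]\n",
    ")\n",
    "inputs = keras.Input(shape=(8,))\n",
    "x = base_model(inputs, training=False)  # Keep the base model in inference mode\n",
    "outputs = keras.layers.Dense(1)(x)\n",
    "model = keras.Model(inputs, outputs)\n",
    "\n",
    "# Unfreeze the base model for fine-tuning.\n",
    "base_model.trainable = True\n",
    "\n",
    "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
    "optimizer = keras.optimizers.Adam(1e-5)  # Very low learning rate\n",
    "\n",
    "# Toy data standing in for the new dataset.\n",
    "features = np.random.random((16, 8)).astype(\"float32\")\n",
    "labels = np.random.randint(0, 2, size=(16, 1)).astype(\"float32\")\n",
    "dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)\n",
    "\n",
    "for batch_inputs, batch_targets in dataset:\n",
    "    with tf.GradientTape() as tape:\n",
    "        predictions = model(batch_inputs)\n",
    "        loss_value = loss_fn(batch_targets, predictions)\n",
    "    # The unfrozen base model weights are now part of `trainable_weights`.\n",
    "    gradients = tape.gradient(loss_value, model.trainable_weights)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n",
    "```\n",
    "\n",
    "Since a custom loop reads `model.trainable_weights` directly at each step, no\n",
    "`compile()` call is involved here.\n"
   ]
  },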
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset\n",
    "\n",
    "To solidify these concepts, let's walk you through a concrete end-to-end transfer\n",
    "learning & fine-tuning example. We will load the Xception model, pre-trained on\n",
    " ImageNet, and use it on the Kaggle \"cats vs. dogs\" classification dataset.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Getting the data\n",
    "\n",
    "First, let's fetch the cats vs. dogs dataset using TFDS. If you have your own dataset,\n",
    "you'll probably want to use the utility\n",
    "`tf.keras.preprocessing.image_dataset_from_directory` to generate similar labeled\n",
    " dataset objects from a set of images on disk filed into class-specific folders.\n",
    "\n",
    "Transfer learning is most useful when working with very small datasets. To keep our\n",
    "dataset small, we will use 40% of the original training data (25,000 images) for\n",
    " training, 10% for validation, and 10% for testing.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow_datasets as tfds\n",
    "\n",
    "tfds.disable_progress_bar()\n",
    "\n",
    "train_ds, validation_ds, test_ds = tfds.load(\n",
    "    \"cats_vs_dogs\",\n",
    "    # Reserve 10% for validation and 10% for test\n",
    "    split=[\"train[:40%]\", \"train[40%:50%]\", \"train[50%:60%]\"],\n",
    "    as_supervised=True,  # Include labels\n",
    ")\n",
    "\n",
    "print(\"Number of training samples: %d\" % tf.data.experimental.cardinality(train_ds))\n",
    "print(\n",
    "    \"Number of validation samples: %d\" % tf.data.experimental.cardinality(validation_ds)\n",
    ")\n",
    "print(\"Number of test samples: %d\" % tf.data.experimental.cardinality(test_ds))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "These are the first 9 images in the training dataset -- as you can see, they're all\n",
    " different sizes.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "for i, (image, label) in enumerate(train_ds.take(9)):\n",
    "    ax = plt.subplot(3, 3, i + 1)\n",
    "    plt.imshow(image)\n",
    "    plt.title(int(label))\n",
    "    plt.axis(\"off\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We can also see that label 1 is \"dog\" and label 0 is \"cat\".\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Standardizing the data\n",
    "\n",
    "Our raw images have a variety of sizes. In addition, each pixel consists of 3 integer\n",
    "values between 0 and 255 (RGB level values). This isn't a great fit for feeding a\n",
    " neural network. We need to do 2 things:\n",
    "\n",
    "- Standardize to a fixed image size. We pick 150x150.\n",
    "- Normalize pixel values between -1 and 1. We'll do this using a `Normalization` layer as\n",
    " part of the model itself.\n",
    "\n",
    "In general, it's a good practice to develop models that take raw data as input, as\n",
    "opposed to models that take already-preprocessed data. The reason being that, if your\n",
    "model expects preprocessed data, any time you export your model to use it elsewhere\n",
    "(in a web browser, in a mobile app), you'll need to reimplement the exact same\n",
    "preprocessing pipeline. This gets very tricky very quickly. So we should do the least\n",
    " possible amount of preprocessing before hitting the model.\n",
    "\n",
    "Here, we'll do image resizing in the data pipeline (because a deep neural network can\n",
    "only process contiguous batches of data), and we'll do the input value scaling as part\n",
    " of the model, when we create it.\n",
    "\n",
    "Let's resize images to 150x150:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "size = (150, 150)\n",
    "\n",
    "train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n",
    "validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n",
    "test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "In addition, let's batch the data and use caching & prefetching to optimize loading speed.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "\n",
    "train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n",
    "validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n",
    "test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Using random data augmentation\n",
    "\n",
    "When you don't have a large image dataset, it's a good practice to artificially\n",
    " introduce sample diversity by applying random yet realistic transformations to\n",
    "the training images, such as random horizontal flipping or small random rotations. This\n",
    "helps expose the model to different aspects of the training data while slowing down\n",
    " overfitting.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "data_augmentation = keras.Sequential(\n",
    "    [\n",
    "        layers.experimental.preprocessing.RandomFlip(\"horizontal\"),\n",
    "        layers.experimental.preprocessing.RandomRotation(0.1),\n",
    "    ]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Let's visualize what the first image of the first batch looks like after various random\n",
    " transformations:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "for images, labels in train_ds.take(1):\n",
    "    plt.figure(figsize=(10, 10))\n",
    "    first_image = images[0]\n",
    "    for i in range(9):\n",
    "        ax = plt.subplot(3, 3, i + 1)\n",
    "        augmented_image = data_augmentation(\n",
    "            tf.expand_dims(first_image, 0), training=True\n",
    "        )\n",
    "        plt.imshow(augmented_image[0].numpy().astype(\"int32\"))\n",
    "        plt.title(int(labels[0]))\n",
    "        plt.axis(\"off\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Build a model\n",
    "\n",
    "Now let's build a model that follows the blueprint we've explained earlier.\n",
    "\n",
    "Note that:\n",
    "\n",
    "- We add a `Normalization` layer to scale input values (initially in the `[0, 255]`\n",
    " range) to the `[-1, 1]` range.\n",
    "- We add a `Dropout` layer before the classification layer, for regularization.\n",
    "- We make sure to pass `training=False` when calling the base model, so that\n",
    "it runs in inference mode, so that batchnorm statistics don't get updated\n",
    "even after we unfreeze the base model for fine-tuning.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "base_model = keras.applications.Xception(\n",
    "    weights=\"imagenet\",  # Load weights pre-trained on ImageNet.\n",
    "    input_shape=(150, 150, 3),\n",
    "    include_top=False,  # Do not include the ImageNet classifier at the top.\n",
    ")\n",
    "\n",
    "# Freeze the base_model\n",
    "base_model.trainable = False\n",
    "\n",
    "# Create new model on top\n",
    "inputs = keras.Input(shape=(150, 150, 3))\n",
    "x = data_augmentation(inputs)  # Apply random data augmentation\n",
    "\n",
    "# Pre-trained Xception weights require that inputs be scaled\n",
    "# from the [0, 255] range to the [-1, +1] range. The normalization layer\n",
    "# computes outputs = (inputs - mean) / sqrt(var).\n",
    "norm_layer = keras.layers.experimental.preprocessing.Normalization()\n",
    "mean = np.array([127.5] * 3)\n",
    "var = mean ** 2\n",
    "# Scale inputs to [-1, +1]\n",
    "x = norm_layer(x)\n",
    "norm_layer.set_weights([mean, var])\n",
    "\n",
    "# The base model contains batchnorm layers. We want to keep them in inference mode\n",
    "# when we unfreeze the base model for fine-tuning, so we make sure that the\n",
    "# base_model is running in inference mode here.\n",
    "x = base_model(x, training=False)\n",
    "x = keras.layers.GlobalAveragePooling2D()(x)\n",
    "x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout\n",
    "outputs = keras.layers.Dense(1)(x)\n",
    "model = keras.Model(inputs, outputs)\n",
    "\n",
    "model.summary()\n"
   ]
  },
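  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To check the scaling arithmetic used for the `Normalization` layer above, the same\n",
    "formula can be applied by hand with NumPy. The three example pixels are made up for\n",
    "illustration; `mean` and `var` match the values set in the model code.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "mean = np.array([127.5] * 3)\n",
    "var = mean ** 2\n",
    "\n",
    "# Three example RGB pixels: black, mid-gray, white.\n",
    "pixels = np.array(\n",
    "    [[0.0, 0.0, 0.0], [127.5, 127.5, 127.5], [255.0, 255.0, 255.0]]\n",
    ")\n",
    "\n",
    "# outputs = (inputs - mean) / sqrt(var)\n",
    "scaled = (pixels - mean) / np.sqrt(var)\n",
    "print(scaled)\n",
    "```\n",
    "\n",
    "Black maps to -1, mid-gray to 0 and white to +1, so the full `[0, 255]` range lands in\n",
    "`[-1, 1]`.\n"
   ]
  },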
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Train the top layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer=keras.optimizers.Adam(),\n",
    "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
    "    metrics=[keras.metrics.BinaryAccuracy()],\n",
    ")\n",
    "\n",
    "epochs = 20\n",
    "model.fit(train_ds, epochs=epochs, validation_data=validation_ds)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Do a round of fine-tuning of the entire model\n",
    "\n",
    "Finally, let's unfreeze the base model and train the entire model end-to-end with a low\n",
    " learning rate.\n",
    "\n",
    "Importantly, although the base model becomes trainable, it is still running in\n",
    "inference mode since we passed `training=False` when calling it when we built the\n",
    "model. This means that the batch normalization layers inside won't update their batch\n",
    "statistics. If they did, they would wreak havoc on the representations learned by the\n",
    " model so far.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Unfreeze the base_model. Note that it keeps running in inference mode\n",
    "# since we passed `training=False` when calling it. This means that\n",
    "# the batchnorm layers will not update their batch statistics.\n",
    "# This prevents the batchnorm layers from undoing all the training\n",
    "# we've done so far.\n",
    "base_model.trainable = True\n",
    "model.summary()\n",
    "\n",
    "model.compile(\n",
    "    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate\n",
    "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
    "    metrics=[keras.metrics.BinaryAccuracy()],\n",
    ")\n",
    "\n",
    "epochs = 10\n",
    "model.fit(train_ds, epochs=epochs, validation_data=validation_ds)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "After 10 epochs, fine-tuning gains us a nice improvement here.\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "transfer_learning",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}